import asyncio
import os
import json
import shutil
import re

from datetime import datetime
from serial import Serial

import numpy as np
import matplotlib.pyplot as plt

from py_pli.pylib import VUnits
from py_pli.pylib import GlobalVar
from py_pli.pylib import send_gc_event

from predefined_tasks.common.helper import send_to_gc

from config_enum import hal_enum as hal_config

from virtualunits.HAL import HAL
from virtualunits.vu_node_application import VUNodeApplication
from virtualunits.vu_measurement_unit import VUMeasurementUnit
from virtualunits.VirtualTemperatureUnit import VirtualTemperatureUnit
from virtualunits.VirtualFanUnit import VUFanControl
from virtualunits.meas_seq_generator import meas_seq_generator
from virtualunits.meas_seq_generator import TriggerSignal
from virtualunits.meas_seq_generator import OutputSignal
from virtualunits.meas_seq_generator import MeasurementChannel
from virtualunits.meas_seq_generator import IntegratorMode
from virtualunits.meas_seq_generator import AnalogControlMode

from urpc.nodefunctions import NodeFunctions
from urpc.measurementfunctions import MeasurementFunctions

from urpc_enum.measurementparameter import MeasurementParameter

from fleming.common.node_io import FMBAnalogOutput
from fleming.common.node_io import EEFAnalogInput
from fleming.common.node_io import EEFAnalogOutput
from fleming.common.node_io import EEFDigitalOutput


# Handles to the virtual units used by this module. They are resolved once, at
# import time, from the global VUnits singleton.
hal_unit: HAL = VUnits.instance.hal
eef_unit: VUNodeApplication = hal_unit.nodes['EEFNode']
fmb_unit: VUNodeApplication = hal_unit.nodes['Mainboard']
meas_unit: VUMeasurementUnit = hal_unit.measurementUnit
pmt1_cooling: VirtualTemperatureUnit = hal_unit.pmt_ch1_Cooling
pmt2_cooling: VirtualTemperatureUnit = hal_unit.pmt_ch2_Cooling
uslum_fan: VUFanControl = hal_unit.usLum_Fan

# Endpoints of the units above, used for direct parameter / IO access.
eef_endpoint: NodeFunctions = eef_unit.endpoint
fmb_endpoint: NodeFunctions = fmb_unit.endpoint
meas_endpoint: MeasurementFunctions = meas_unit.endpoint

# Report locations. The root directory comes from the HAL configuration; plot
# images and archived reports are stored in sub-directories of it.
gc_report_path = hal_unit.configManager.get_config(hal_config.Application.GCReportPath)
gc_images_path = os.path.join(gc_report_path, 'pmt_adjust_images')
gc_archive_path = os.path.join(gc_report_path, 'pmt_adjust_archive')

# Create the report directories if they do not exist yet (import-time side effect).
os.makedirs(gc_report_path, exist_ok=True)
os.makedirs(gc_images_path, exist_ok=True)
os.makedirs(gc_archive_path, exist_ok=True)


async def pmt_adjustment(channel='pmt1', pmt_serial='TEST'):
    """
    Run the complete PMT adjustment procedure and write the adjustment report.

    The procedure starts the firmware, stabilizes the PMT, adjusts the
    discriminator level, adjusts the high voltage setting and, for 'pmt1' and
    'pmt2' only, determines the analog scaling factors. The report is written
    to 'pmt_adjustment.csv' in the report directory and finally copied to the
    archive directory as 'PMT_<pmt_serial>_Adjustment.csv'.

    Args:
        channel: 'pmt1', 'pmt2' or 'pmt3' (case-insensitive). An empty string
            selects 'pmt1'.
        pmt_serial: Serial number of the PMT (converted to upper case). An
            empty string selects 'TEST'. Only [A-Z0-9_-] characters are allowed
            because the serial becomes part of the archive file name.

    Returns:
        A "pmt_adjustment stopped by user" message when the stop flag was set
        during the procedure, otherwise None.

    Raises:
        ValueError: If channel or pmt_serial is invalid.
        Exception: Propagated from the individual adjustment steps.
    """
    channel = str(channel).lower() if (channel != '') else 'pmt1'
    pmt_serial = str(pmt_serial).upper() if (pmt_serial != '') else 'TEST'

    if channel not in ['pmt1', 'pmt2', 'pmt3']:
        raise ValueError(f"channel must be 'pmt1', 'pmt2' or 'pmt3'.")
    # fullmatch is required here: re.match only anchors at the start, so a
    # serial like 'A/../B' would pass and end up in the archive file name.
    if not re.fullmatch(r'[A-Z0-9_-]+', pmt_serial):
        raise ValueError(f"pmt_serial must only contain [A-Z0-9_-] characters.")

    report_file = os.path.join(gc_report_path, 'pmt_adjustment.csv')

    GlobalVar.set_stop_gc(False)
    await send_to_gc(f"Starting PMT adjustment")

    # Start a fresh report; the adjustment steps below append to it.
    with open(report_file, 'w') as report:
        report.write(f"PMT_SN: {pmt_serial}\n")
        report.write(f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        report.write('\n')

    await send_to_gc(f"Starting Firmware")
    await asyncio.gather(
        fmb_unit.StartFirmware(),
        eef_unit.StartFirmware(),
    )
    if GlobalVar.get_stop_gc():
        return f"pmt_adjustment stopped by user"

    ### PMT Stabilization ######################################################

    temperature = 18    # stabilization temperature in °C
    duration    = 30    # stabilization duration in minutes

    await send_to_gc(f"Stabilizing PMT")

    if channel == 'pmt1':
        await pmt1_cooling.InitializeDevice()
        await pmt1_cooling.set_target_temperature(temperature)
        await pmt1_cooling.enable()
    if channel == 'pmt2':
        await pmt2_cooling.InitializeDevice()
        await pmt2_cooling.set_target_temperature(temperature)
        await pmt2_cooling.enable()
    if channel == 'pmt3':
        await uslum_fan.InitializeDevice()
        await uslum_fan.enable()
        await eef_endpoint.SetDigitalOutput(EEFDigitalOutput.HTSALPHATECENABLE, 0)
        await meas_endpoint.SetParameter(MeasurementParameter.HTSAlphaLaserEnable, 0)

    # Wait in one second slices so that a stop request is noticed quickly.
    for minute in range(duration):
        await send_to_gc(f"{duration - minute} min remaining...")
        for _ in range(60):
            if GlobalVar.get_stop_gc():
                return f"pmt_adjustment stopped by user"

            await asyncio.sleep(1)

    await send_to_gc(f" ")

    ### Discriminator Adjustment ###############################################

    pmt3_tec_power   = 0.5  # 0.0 is maximum!!!
    pmt3_laser_power = 1.0  # 1.0 is maximum and current default setting

    await send_to_gc(f"Adjusting Discriminator Level:")

    if (channel == 'pmt3'):
        await eef_endpoint.SetAnalogOutput(EEFAnalogOutput.HTSALPHATEC, pmt3_tec_power)
        await meas_endpoint.SetParameter(MeasurementParameter.HTSAlphaLaserPower, pmt3_laser_power)
        await eef_endpoint.SetDigitalOutput(EEFDigitalOutput.HTSALPHATECENABLE, 1)
        await meas_endpoint.SetParameter(MeasurementParameter.HTSAlphaLaserEnable, 1)

    try:
        results = await pmt_adjust_discriminator(channel, dl_start=0.0, dl_stop=1.0, dl_step=0.001, report_file=report_file)
    finally:
        # Switch the TEC and laser off again even when the scan failed.
        if (channel == 'pmt3'):
            await eef_endpoint.SetDigitalOutput(EEFDigitalOutput.HTSALPHATECENABLE, 0)
            await meas_endpoint.SetParameter(MeasurementParameter.HTSAlphaLaserEnable, 0)

    # The scan returns a message string instead of a dict when it was stopped,
    # so the stop flag must be checked before the results are used.
    if GlobalVar.get_stop_gc():
        return f"pmt_adjustment stopped by user"

    dl = results['dl']

    await send_to_gc(f" ")
    await send_to_gc(f"dl = {dl:.3f}")
    await send_to_gc(f" ")

    ### High Voltage Adjustment ################################################

    await send_to_gc(f"Adjusting High Voltage Setting:")

    if (channel == 'pmt1') or (channel == 'pmt2'):
        results = await pmt_adjust_high_voltage(channel, dl, hv_start=0.400, hv_stop=0.600, hv_step=0.005, report_file=report_file)
    if (channel == 'pmt3'):
        results = await pmt_adjust_high_voltage(channel, dl, hv_start=0.400, hv_stop=0.700, hv_step=0.005, report_file=report_file)
    if GlobalVar.get_stop_gc():
        return f"pmt_adjustment stopped by user"

    hv = results['hv']
    ppr_ns = results['ppr_ns']

    await send_to_gc(f" ")
    await send_to_gc(f"hv = {hv:.3f}, ppr = {ppr_ns:.2f}e-9")
    await send_to_gc(f" ")

    ### Analog Adjustment ######################################################

    if (channel == 'pmt1') or (channel == 'pmt2'):

        await send_to_gc(f"Adjusting Analog Scaling:")

        results = await pmt_adjust_analog(channel, dl, hv, ppr_ns, report_file=report_file)
        if GlobalVar.get_stop_gc():
            return f"pmt_adjustment stopped by user"

        als = results['als']
        ahs = results['ahs']

        await send_to_gc(f" ")
        await send_to_gc(f"als = {als}, ahs = {ahs}")
        await send_to_gc(f" ")

    ### Summary ################################################################

    # NOTE(review): send_to_gc(..., report=report) presumably also writes the
    # line to the report file - confirm against the helper.
    with open(report_file, 'a') as report:
        await send_to_gc(f"Adjusted Parameter:", report=report)
        await send_to_gc(f" ", report=report)

        await send_to_gc(f"DiscriminatorLevel_ = {dl:.3f}", report=report)
        await send_to_gc(f"HighVoltageSetting_ = {hv:.3f}", report=report)
        await send_to_gc(f"Pulse_pair_res_ = {ppr_ns:.2f}e-9", report=report)
        if (channel == 'pmt1') or (channel == 'pmt2'):
            await send_to_gc(f" ", report=report)
            await send_to_gc(f"AnalogCountingEquivalent_ = {als}", report=report)
            await send_to_gc(f"AnalogHighRangeScale_ = {ahs}", report=report)

    shutil.copy(report_file, os.path.join(gc_archive_path, f"PMT_{pmt_serial}_Adjustment.csv"))


async def pmt_adjust_discriminator(channel, dl_start, dl_stop, dl_step, report_file=''):
    """
    Scan the discriminator level with the high voltage disabled and derive the
    discriminator level setting from the position of the noise peak.

    For every level in [dl_start, dl_stop] one counting measurement is taken.
    The peak center is the middle of the range where the count rate exceeds
    90% of its maximum. When the maximum count rate is below 100 the largest
    level whose dead time exceeds 50% of the maximum dead time is used instead.
    The result is the peak center plus a channel specific offset.

    Args:
        channel: 'pmt1', 'pmt2' or 'pmt3'.
        dl_start: First discriminator level of the scan.
        dl_stop: Last discriminator level of the scan (inclusive).
        dl_step: Step size of the scan.
        report_file: File the report is appended to. An empty string selects
            'pmt_adjust_discriminator.csv' in the report directory.

    Returns:
        {'dl': <discriminator level>} or a "... stopped by user" message when
        the stop flag was set during the scan.

    Raises:
        Exception: If no peak was found or the peak is wider than the channel
            specific limit.
    """
    window_ms = 100.0

    dl_offset = {'pmt1':0.05, 'pmt2':0.05, 'pmt3':0.075}
    dl_width_limit = {'pmt1':0.04, 'pmt2':0.04, 'pmt3':0.06}

    report_file = str(report_file) if (report_file != '') else os.path.join(gc_report_path, f"pmt_adjust_discriminator.csv")

    with open(report_file, 'a') as report:
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        report.write(f"pmt_adjust_discriminator(channel={channel}, dl_start={dl_start:.3f}, dl_stop={dl_stop:.3f}, dl_step={dl_step:.3f}) started at {timestamp}\n")
        report.write(f"temperature: {await pmt_get_temperature(channel):.2f}°C\n")
        report.write('\n')

        await pmt_set_hv_enable(channel, 0)

        # The small epsilon makes dl_stop part of the scan range.
        dl_range = np.arange(dl_start, (dl_stop + 1e-6), dl_step)
        cps = np.zeros_like(dl_range)
        dt = np.zeros_like(dl_range)

        await send_to_gc(f"dl    ; cps        ; dt", report=report)

        for i, dl in enumerate(dl_range):
            await pmt_set_dl(channel, dl)
            await asyncio.sleep(0.1)
            results = await pmt_counting_measurement(window_ms, iterations=1)
            # The measurement returns a message string when stopped, so check
            # the stop flag before using the results.
            if GlobalVar.get_stop_gc():
                return f"pmt_adjust_discriminator stopped by user"

            cps[i] = results[f"{channel}_cps_mean"]
            dt[i] = results[f"{channel}_dt_mean"]

            await send_to_gc(f"{dl:5.3f} ; {cps[i]:10.0f} ; {dt[i]:10.0f}", report=report)

        report.write('\n')

        cps_max = np.max(cps)
        dt_max = np.max(dt)

        dl_peak = dl_range[cps > (cps_max * 0.9)]
        dt_peak = dl_range[dt > (dt_max * 0.5)]

        if (len(dl_peak) > 0) and (cps_max >= 100):
            dl_center = (np.max(dl_peak) + np.min(dl_peak)) / 2
        elif (len(dt_peak) > 0):
            # Fallback to the dead time when the count rate peak is too low.
            dl_center = np.max(dt_peak)
        else:
            raise Exception(f"Failed to adjust discriminator level. No peak found in the measurement.")

        dl = np.round((dl_center + dl_offset[channel]), 3)

        report.write(f"dl ; {dl:.3f}")
        report.write('\n')

        # Plot only the region around the peak.
        plot_min = dl_center - dl_offset[channel]
        plot_max = dl_center + 2 * dl_offset[channel]
        plot_select = np.logical_and((dl_range >= plot_min), (dl_range <= plot_max))
        await plot_dl_scan(dl_range[plot_select], cps[plot_select], dt[plot_select], dl, file_name='dl_scan.png')

        #TODO Filter outlier counts when zero before and after.
        # Width = distance from the peak center to the largest level that still counted.
        dl_noise = dl_range[cps > 0.0]
        dl_width = np.max(dl_noise) - dl_center if (len(dl_noise) > 0) else 0.0

        if dl_width > dl_width_limit[channel]:
            raise Exception(f"Failed to adjust discriminator level. Peak is too wide. (center: {dl_center:.3f}, width: {dl_width:.3f})")

    return {'dl':dl}


async def plot_dl_scan(dl_range, cps, dt, dl, file_name='graph.png'):
    """
    Plot count rate and dead time of a discriminator scan side by side, mark
    the selected discriminator level `dl` with a red vertical line, save the
    figure to the images directory and send a 'RefreshGraph' event for it.
    """
    plt.clf()

    # (title, y-values, use symlog y-axis) for the left and right panel.
    panels = (
        ('Counts Per Second', cps, True),
        ('Dead Time', dt, False),
    )
    for position, (heading, values, use_symlog) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.title(heading)
        plt.xlabel('discriminator level')
        if use_symlog:
            plt.yscale('symlog', linthresh=1)
        plt.plot(dl_range, values)
        plt.axvline(dl, color='r')

    image_path = os.path.join(gc_images_path, file_name)
    plt.savefig(image_path)

    relative_path = os.path.join('pmt_adjust_images', file_name)
    await send_gc_event('RefreshGraph', file_name=relative_path)


async def pmt_adjust_high_voltage(channel, dl, hv_start, hv_stop, hv_step, report_file=''):
    """
    Run a signal scan for every high voltage setting in [hv_start, hv_stop]
    and select the highest setting whose linearity reaches the required limit.

    Returns:
        {'hv': <selected setting>, 'ppr_ns': <pulse pair resolution at that
        setting>} or a "... stopped by user" message when the stop flag was set.

    Raises:
        Exception: If no setting reaches the required linearity.
    """
    linearity_limit = 0.9997

    hv_range = np.arange(hv_start, (hv_stop + 1e-6), hv_step)

    # One result column per quantity reported by pmt_signal_scan.
    columns = ('ppr_ns', 'linearity', 'sensitivity', 'rel_error')
    table = {name: np.zeros_like(hv_range) for name in columns}

    await send_to_gc(f"hv    ; ppr_ns ; linearity ; sensitivity ; rel_error")

    for row, hv in enumerate(hv_range):
        scan = await pmt_signal_scan(channel, dl, hv, report_file=report_file)
        if GlobalVar.get_stop_gc():
            return f"pmt_adjust_high_voltage stopped by user"

        for name in columns:
            table[name][row] = scan[name]

        row_ppr = table['ppr_ns'][row]
        row_lin = table['linearity'][row]
        row_sens = table['sensitivity'][row]
        row_err = table['rel_error'][row]
        await send_to_gc(f"{hv:5.3f} ; {row_ppr:6.2f} ; {row_lin:9.6f} ; {row_sens:11.3f} ; {row_err:9.2%}")

    linearity = table['linearity']
    if (np.max(linearity) < linearity_limit):
        raise Exception(f"Failed to adjust high voltage setting. Required linearity not reached. (best linearity: {np.max(linearity):.6f})")

    # Highest setting that still fulfils the linearity requirement.
    passing = linearity >= linearity_limit
    hv = np.max(hv_range[passing])
    ppr_ns = table['ppr_ns'][hv_range == hv][0]

    return {'hv':hv, 'ppr_ns':ppr_ns}


async def pmt_signal_scan(channel, dl, hv, report_file=''):
    """
    Measure the dark count rate and the count rate for every LED current of the
    LED characteristic, then evaluate pulse pair resolution, linearity,
    sensitivity and the maximum relative error of the corrected count rates.

    The LED is switched off and the high voltage is disabled again when the
    measurement ends, including when it is stopped by the user or fails.

    Args:
        channel: 'pmt1', 'pmt2' or 'pmt3'.
        dl: Discriminator level to use.
        hv: High voltage setting to use.
        report_file: File the report is appended to. An empty string selects
            'pmt_signal_scan.csv' in the report directory.

    Returns:
        {'ppr_ns', 'linearity', 'sensitivity', 'rel_error'} or a
        "... stopped by user" message when the stop flag was set.

    Raises:
        ValueError: If channel is not 'pmt1', 'pmt2' or 'pmt3'.
    """
    dark_window_ms    = 1000.0
    dark_iterations   = 20
    signal_window_ms  = 1000.0
    signal_iterations = 1

    if (channel == 'pmt1') or (channel == 'pmt2'):
        led = load_led_characteristic(source='fmb', channel='led1', type='green', suffix='')
    elif (channel == 'pmt3'):
        led = load_led_characteristic(source='fmb', channel='led3', type='green', suffix='')
    else:
        raise ValueError(f"channel must be 'pmt1', 'pmt2' or 'pmt3'.")

    report_file = str(report_file) if (report_file != '') else os.path.join(gc_report_path, f"pmt_signal_scan.csv")

    with open(report_file, 'a') as report:
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        report.write(f"pmt_signal_scan(channel={channel}, dl={dl:.3f}, hv={hv:.3f}) started at {timestamp}\n")
        report.write(f"temperature: {await pmt_get_temperature(channel):.2f}°C\n")
        report.write('\n')

        await pmt_set_dl(channel, dl)
        await pmt_set_hv(channel, hv)
        await asyncio.sleep(0.2)

        try:
            await pmt_set_hv_enable(channel, 1)
            await asyncio.sleep(0.5)

            results = await pmt_counting_measurement(dark_window_ms, dark_iterations)
            # The measurement returns a message string when stopped, so check
            # the stop flag before using the results.
            if GlobalVar.get_stop_gc():
                return f"pmt_signal_scan stopped by user"

            dark_cps = results[f"{channel}_cps_mean"]
            dark_std = results[f"{channel}_cps_std"]

            report.write(f"dark_cps ; {dark_cps:.6f} ; dark_std ; {dark_std:.6f}\n")
            report.write('\n')

            signal_cps = np.zeros_like(led['current'])

            report.write(f"signal      ; cps\n")

            for i, current in enumerate(led['current']):
                await set_led_current(current, led['source'], led['channel'], led['type'])
                results = await pmt_counting_measurement(signal_window_ms, signal_iterations)
                if GlobalVar.get_stop_gc():
                    return f"pmt_signal_scan stopped by user"

                signal_cps[i] = results[f"{channel}_cps_mean"]

                report.write(f"{led['power'][i]:11.2f} ; {signal_cps[i]:11.0f}\n")

            report.write('\n')
        finally:
            # Never leave the LED on or the high voltage enabled, not even
            # after a stop request or an exception.
            await set_led_current(0, led['source'], led['channel'], led['type'])
            await pmt_set_hv_enable(channel, 0)

        ppr_ns, linearity = pmt_calculate_ppr_ns(led['power'], signal_cps, optimization='r_squared_min')

        signal_cps = pmt_calculate_correction(signal_cps, ppr_ns)

        # Sensitivity is only computed when the lowest signal is above the dark count rate.
        if (dark_cps < signal_cps[0]):
            sensitivity = 3 * led['power'][0] * dark_std / (signal_cps[0] - dark_cps) * 1e3
        else:
            sensitivity = 0.0

        # Reference: line through the origin with the mean slope of the two
        # lowest signal points. rel_error is the largest deviation from it.
        signal_cps_ref = np.mean(signal_cps[:2] / led['power'][:2]) * led['power']
        rel_error = np.abs((signal_cps - signal_cps_ref) / signal_cps_ref)
        rel_error = np.max(rel_error)

        report.write(f"ppr_ns ; {ppr_ns:.2f} ; linearity ; {linearity:.6f} ; sensitivity ; {sensitivity:.3f} ; rel_error ; {rel_error:.2%}\n")
        report.write('\n')

        await plot_signal_scan(led['power'], signal_cps, signal_cps_ref, hv, file_name=f"signal_scan_hv_{hv:.3f}.png")

    return {'ppr_ns':ppr_ns, 'linearity':linearity, 'sensitivity':sensitivity, 'rel_error':rel_error}


async def plot_signal_scan(signal, cps, cps_ref, hv, file_name='graph.png'):
    """
    Plot the measured and the reference count rate of a signal scan over the
    signal power on log-log axes, save the figure to the images directory and
    send a 'RefreshGraph' event for it.
    """
    plt.clf()

    plt.title(f"Signal Scan HV={hv:.3f}")
    plt.xlabel('power [nW]')
    plt.ylabel('counts per second')
    for set_scale in (plt.xscale, plt.yscale):
        set_scale('log')

    # Reference first so that the measured curve is drawn on top of it.
    curves = (
        (cps_ref, 'reference', 'r'),
        (cps, 'measured', 'b'),
    )
    for values, legend_label, colour in curves:
        plt.plot(signal, values, label=legend_label, color=colour)
    plt.legend(loc='upper left')

    image_path = os.path.join(gc_images_path, file_name)
    plt.savefig(image_path)

    relative_path = os.path.join('pmt_adjust_images', file_name)
    await send_gc_event('RefreshGraph', file_name=relative_path)


async def pmt_adjust_analog(channel, dl, hv, ppr_ns, report_file=''):
    """
    Determine the analog scaling factors of a PMT channel.

    The measurement window is chosen so that the brightest LED signal reaches
    the analog low limit. Counting, analog low and analog high values are then
    measured for every LED current of the LED characteristic and the scaling
    factors are calculated as least-squares slopes through the origin.

    The LED is switched off and the high voltage is disabled again when the
    measurement ends, including when it is stopped by the user or fails.

    Args:
        channel: 'pmt1' or 'pmt2' (the analog measurement only provides
            results for these channels).
        dl: Discriminator level to use.
        hv: High voltage setting to use.
        ppr_ns: Pulse pair resolution in ns used for the count rate correction.
        report_file: File the report is appended to. An empty string selects
            'pmt_adjust_analog.csv' in the report directory.

    Returns:
        {'als': <counting / analog low>, 'ahs': <analog low / analog high>} or
        a "... stopped by user" message when the stop flag was set.

    Raises:
        Exception: If the brightest signal exceeds the analog low limit even
            with the shortest measurement window.
    """
    led = load_led_characteristic(source='fmb', channel='led2', type='green', suffix='')
    al_limit = (65536 - 1310) * 0.8     # 80% of the analog low range with ~100 mV offset voltage.
    iterations = 100

    report_file = str(report_file) if (report_file != '') else os.path.join(gc_report_path, f"pmt_adjust_analog.csv")

    with open(report_file, 'a') as report:
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        report.write(f"pmt_adjust_analog(channel={channel}, dl={dl:.3f}, hv={hv:.3f}, ppr_ns={ppr_ns:.2f}) started at {timestamp}\n")
        report.write(f"temperature: {await pmt_get_temperature(channel):.2f}°C\n")
        report.write('\n')

        await pmt_set_dl(channel, dl)
        await pmt_set_hv(channel, hv)
        await asyncio.sleep(0.2)

        try:
            await pmt_set_hv_enable(channel, 1)
            await asyncio.sleep(0.5)

            await set_led_current(np.max(led['current']), led['source'], led['channel'], led['type'])

            # Try a 1 ms window first and fall back to 0.1 ms when the
            # brightest signal is above the analog low limit.
            window_ms = 1.0
            results = await pmt_analog_measurement(window_ms, iterations)

            if (results[f"{channel}_al_mean"] > al_limit):
                window_ms = 0.1
                results = await pmt_analog_measurement(window_ms, iterations)

            if (results[f"{channel}_al_mean"] > al_limit):
                al_max = results[f"{channel}_al_mean"]
                raise Exception(f"Failed to adjust analog scaling factors. Signal is too bright. (max analog low: {al_max:.0f})")

            # Scale the window so that the brightest signal reaches the limit.
            window_ms = np.round(al_limit / results[f"{channel}_al_mean"] * window_ms, 3)

            report.write(f"window_ms ;  {window_ms:.3f}\n")
            report.write('\n')

            cnt = np.zeros_like(led['current'])
            al = np.zeros_like(led['current'])
            ah = np.zeros_like(led['current'])

            await send_to_gc(f"signal      ; counting    ; analog_low  ; analog_high", report=report)

            for i, current in enumerate(led['current']):
                await set_led_current(current, led['source'], led['channel'], led['type'])
                results = await pmt_analog_measurement(window_ms, iterations)
                if GlobalVar.get_stop_gc():
                    return f"pmt_adjust_analog stopped by user"

                cnt[i] = results[f"{channel}_cnt_mean"]
                al[i] = results[f"{channel}_al_mean"]
                ah[i] = results[f"{channel}_ah_mean"]

                await send_to_gc(f"{led['power'][i]:11.2f} ; {cnt[i]:11.0f} ; {al[i]:11.0f} ; {ah[i]:11.0f}", report=report)

            report.write('\n')
        finally:
            # Never leave the LED on or the high voltage enabled, not even
            # after a stop request or an exception.
            await set_led_current(0, led['source'], led['channel'], led['type'])
            await pmt_set_hv_enable(channel, 0)

        # Count Rate Correction
        ppr_ms = ppr_ns * 1e-6
        cnt = cnt / (1 - cnt * ppr_ms / window_ms)

        # Least-squares slopes through the origin.
        als = np.sum(cnt * al) / np.sum(al ** 2)
        ahs = np.sum(al * ah) / np.sum(ah ** 2)

        report.write(f"als ; {als} ; ahs ; {ahs}\n")
        report.write('\n')

        await plot_analog_scan(led['power'], cnt, (al * als), (ah * ahs * als), file_name='analog_scan.png')

    return {'als':als, 'ahs':ahs}


async def plot_analog_scan(signal, cnt, al, ah, file_name='graph.png'):
    """
    Plot the counting, analog low and analog high values of an analog scan
    over the signal power, save the figure to the images directory and send a
    'RefreshGraph' event for it.
    """
    plt.clf()

    plt.title(f"Analog Adjustment")
    plt.xlabel('power [nW]')
    plt.ylabel('counts')

    # Drawing order: analog high, analog low, counting.
    curves = (
        (ah, 'analog_high', 'g'),
        (al, 'analog_low', 'r'),
        (cnt, 'counting', 'b'),
    )
    for values, legend_label, colour in curves:
        plt.plot(signal, values, label=legend_label, color=colour)
    plt.legend(loc='upper left')

    image_path = os.path.join(gc_images_path, file_name)
    plt.savefig(image_path)

    relative_path = os.path.join('pmt_adjust_images', file_name)
    await send_gc_event('RefreshGraph', file_name=relative_path)


def pmt_calculate_ppr_ns(signal, counts, window_ms=1000, optimization='r_squared_min', optimization_arg=''):
    """
    Search the pulse pair resolution (in ns) that maximizes the linearity of
    the corrected counts over the signal.

    The search starts at 0 ns and climbs with step sizes 1, 0.1 and 0.01 ns.
    For each step size the value is increased until the linearity score gets
    worse. Before refining with a smaller step the search goes back one coarse
    step so that an optimum just below the last value is not missed.

    Args:
        signal: Signal values (e.g. LED power) of the scan.
        counts: Measured counts for each signal value.
        window_ms: Measurement window passed on to the linearity calculation.
        optimization: 'r_squared_min' scores the minimum r² over all partial
            ranges, 'r_squared' scores the r² of one partial range.
        optimization_arg: For 'r_squared' the index of the last signal point of
            the scored range. An empty string selects the last signal point.

    Returns:
        (ppr_ns, linearity) with linearity being the minimum r² over all
        partial ranges, or (0.0, 0.0) when the search exceeds 25 ns.

    Raises:
        ValueError: If the optimization option or signal index is invalid.
    """
    ppr_ns_max = 25.0
    precision = 2

    def linearity_profile(candidate_ns):
        # r² of the ranges signal[:3], signal[:4], ... Entries 0 and 1 stay 1.0.
        # An explicit float array is used because np.ones_like(signal) would
        # inherit an integer dtype from the signal and truncate the r² values.
        profile = np.ones(len(signal), dtype=float)
        for last in range(2, len(signal)):
            profile[last] = pmt_calculate_linearity(signal[:(last + 1)], counts[:(last + 1)], candidate_ns, window_ms)
        return profile

    # Select the score function once instead of on every search iteration.
    if (optimization == 'r_squared'):
        index = int(optimization_arg) if (optimization_arg != '') else (len(signal) - 1)
        if (index < 2) or (index >= len(signal)):
            raise ValueError(f"Invalid optimization signal index.")

        def score(profile):
            return profile[index]
    elif (optimization == 'r_squared_min'):
        score = np.min
    else:
        raise ValueError(f"Invalid optimization option.")

    ppr_ns = 0.0
    linearity = linearity_profile(ppr_ns)

    for digit in range(precision + 1):
        step_ns = 1.0 / 10**digit

        if (ppr_ns >= (step_ns * 10.0)):
            # Take one step back for cases like this: Optimal ppr is 11.5 and r²(11.0) < r²(12.0) > r²(13.0).
            ppr_ns = round((ppr_ns - step_ns * 10.0), precision)
            linearity = linearity_profile(ppr_ns)

        stop = False
        while not stop:
            temp_ppr_ns = round((ppr_ns + step_ns), precision)
            temp_linearity = linearity_profile(temp_ppr_ns)

            # Stop climbing as soon as the next step makes the score worse.
            stop = score(temp_linearity) < score(linearity)

            if not stop:
                ppr_ns = temp_ppr_ns
                linearity = temp_linearity

            if (ppr_ns > ppr_ns_max):
                # Failed to calculate pulse pair resolution
                return (0.0, 0.0)

    return (ppr_ns, np.min(linearity))


def pmt_calculate_linearity(signal, counts, ppr_ns, window_ms=1000):
    """
    Return r² (squared Pearson correlation coefficient) between the signal and
    the counts after the pulse pair resolution correction.
    """
    corrected = pmt_calculate_correction(counts, ppr_ns, window_ms)

    # Off-diagonal element of the 2x2 correlation matrix.
    r_value = np.corrcoef(signal, corrected)[0, 1]

    return r_value**2


def pmt_calculate_linearity_r_squared(signal, counts, ppr_ns, window_ms=1000):
    """
    Return r² between the signal and the corrected counts for every partial
    range signal[:3], signal[:4], ... of the scan.

    Returns:
        Float array with one entry per signal point. Entry i holds the r² of
        the first i + 1 points. Entries 0 and 1 are 1.0.
    """
    counts = pmt_calculate_correction(counts, ppr_ns, window_ms)

    # Explicit float array: np.ones_like(signal) would inherit an integer
    # dtype from the signal and truncate the r² values.
    r_squared = np.ones(len(signal), dtype=float)

    for i in range(2, len(signal)):
        correlation_coef = np.corrcoef(signal[:(i + 1)], counts[:(i + 1)])
        r_value = correlation_coef[0,1]
        r_squared[i] = r_value**2

    return r_squared


def pmt_calculate_linearity_rel_error(signal, counts, ppr_ns, window_ms=1000):
    """
    Return the relative error of the corrected counts against a reference line
    through the origin whose slope is the mean counts/signal ratio of the first
    two points.
    """
    corrected = pmt_calculate_correction(counts, ppr_ns, window_ms)

    reference = np.mean(corrected[:2] / signal[:2]) * signal
    deviation = (corrected - reference) / reference

    return np.abs(deviation)


def pmt_calculate_linearity_offset(signal, counts, ppr_ns, window_ms=1000):
    """
    Return the absolute deviation of the corrected counts from a first degree
    polynomial fit, divided by the signal.
    """
    corrected = pmt_calculate_correction(counts, ppr_ns, window_ms)

    fit = np.polyfit(signal, corrected, deg=1)
    fitted = np.polyval(fit, signal)

    return np.abs(corrected - fitted) / signal


def pmt_calculate_correction(counts, ppr_ns, window_ms=1000):
    """
    Apply the pulse pair resolution correction to the counts.

    Each value is divided by (1 - counts * ppr_ns * 1e-6 / window_ms). The
    input is converted to a numpy array first, so lists and scalars are
    accepted as well.
    """
    measured = np.array(counts)
    lost_fraction = measured * ppr_ns * 1e-6 / window_ms

    return measured / (1 - lost_fraction)


async def pmt_get_temperature(channel):
    """
    Return the temperature reading of the given PMT channel.

    'pmt1' and 'pmt2' use the feedback value of their cooling unit. 'pmt3'
    converts the first value of the HTSALPHATEMPIN analog input with a linear
    formula. Any other channel yields 0.0.
    """
    if channel == 'pmt1':
        return await pmt1_cooling.get_feedback_value()
    if channel == 'pmt2':
        return await pmt2_cooling.get_feedback_value()
    if channel == 'pmt3':
        analog_in = await eef_endpoint.GetAnalogInput(EEFAnalogInput.HTSALPHATEMPIN)
        return analog_in[0] * -31.81 + 41.992
    return 0.0


async def pmt_set_dl(channel, dl):
    """
    Set the discriminator level parameter of the given PMT channel.
    Channels other than 'pmt1', 'pmt2' and 'pmt3' are ignored.
    """
    if channel == 'pmt1':
        parameter = MeasurementParameter.PMT1DiscriminatorLevel
    elif channel == 'pmt2':
        parameter = MeasurementParameter.PMT2DiscriminatorLevel
    elif channel == 'pmt3':
        parameter = MeasurementParameter.PMTUSLUMDiscriminatorLevel
    else:
        return

    await meas_endpoint.SetParameter(parameter, dl)


async def pmt_set_hv(channel, hv):
    """
    Set the high voltage setting parameter of the given PMT channel.
    Channels other than 'pmt1', 'pmt2' and 'pmt3' are ignored.
    """
    if channel == 'pmt1':
        parameter = MeasurementParameter.PMT1HighVoltageSetting
    elif channel == 'pmt2':
        parameter = MeasurementParameter.PMT2HighVoltageSetting
    elif channel == 'pmt3':
        parameter = MeasurementParameter.PMTUSLUMHighVoltageSetting
    else:
        return

    await meas_endpoint.SetParameter(parameter, hv)


async def pmt_set_hv_enable(channel, enable):
    """
    Set the high voltage enable parameter of the given PMT channel.
    Channels other than 'pmt1', 'pmt2' and 'pmt3' are ignored.
    """
    if channel == 'pmt1':
        parameter = MeasurementParameter.PMT1HighVoltageEnable
    elif channel == 'pmt2':
        parameter = MeasurementParameter.PMT2HighVoltageEnable
    elif channel == 'pmt3':
        parameter = MeasurementParameter.PMTUSLUMHighVoltageEnable
    else:
        return

    await meas_endpoint.SetParameter(parameter, enable)


async def pmt_counting_measurement(window_ms=1000.0, iterations=10):
    """
    Execute the counting measurement `iterations` times and return mean and
    standard deviation of count rate and dead time for all three PMT channels.

    Each raw result is a 64 bit value split into a low and a high 32 bit word.
    It is divided by window_ms and multiplied by 1000.

    Returns:
        Dict with the keys '<pmt>_cps_mean', '<pmt>_dt_mean', '<pmt>_cps_std'
        and '<pmt>_dt_std' for pmt1, pmt2 and pmt3, or a "... stopped by user"
        message when the stop flag was set.
    """
    op_id = 'pmt_counting_measurement'
    meas_unit.ClearOperations()
    await load_pmt_counting_measurement(op_id, window_ms)

    # Index of the first result word of each channel in the result buffer.
    first_word = {'pmt1': 0, 'pmt2': 6, 'pmt3': 12}

    samples = {}
    for name in first_word:
        samples[f"{name}_cps"] = np.zeros(iterations)
        samples[f"{name}_dt"] = np.zeros(iterations)

    for run in range(iterations):
        if GlobalVar.get_stop_gc():
            return f"pmt_counting_measurement stopped by user"

        await meas_unit.ExecuteMeasurement(op_id)
        raw = await meas_unit.ReadMeasurementValues(op_id)

        for name, base in first_word.items():
            count_value = raw[base] + (raw[base + 1] << 32)
            samples[f"{name}_cps"][run] = count_value / window_ms * 1000.0
            dead_time_value = raw[base + 2] + (raw[base + 3] << 32)
            samples[f"{name}_dt"][run] = dead_time_value / window_ms * 1000.0

    # All mean values first, then all standard deviations.
    summary = {}
    for suffix, statistic in (('mean', np.mean), ('std', np.std)):
        for key, values in samples.items():
            summary[f"{key}_{suffix}"] = statistic(values)

    return summary


async def load_pmt_counting_measurement(op_id, window_ms):
    """
    Build the trigger sequence of the counting measurement and load it into the
    measurement unit under `op_id`.

    The sequence clears the 16 result words, sets the HV gate signals of PMT1
    and PMT2, accumulates the pulse counters of PMT1, PMT2 and US_LUM once per
    microsecond of the window (window_ms rounded to whole microseconds) and
    stores count and dead time of each channel as double words.

    Raises:
        ValueError: If window_ms is below 0.001 ms.
    """
    if (window_ms < 0.001):
        raise ValueError(f"window_ms must be greater or equal to 0.001 ms")

    # The window in microseconds is split into full blocks of 65536 iterations
    # (coarse) and the remaining iterations (fine).
    window_coarse, window_fine = divmod(round(window_ms * 1000), 65536)

    hv_gate_delay = 1000000     #  10 ms
    pre_cnt_window = 100        #   1 us

    seq_gen = meas_seq_generator()

    # Result buffer layout per channel: cr_lsb, cr_msb, dt_lsb, dt_msb, alr, ahr
    # for pmt1 (0..5) and pmt2 (6..11); pmt3 only has the first four (12..15).
    for address in range(16):
        seq_gen.ClearResultBuffer(relative=False, dword=False, addrReg=0, addr=address)

    # (measurement channel, first result word of that channel)
    counters = (
        (MeasurementChannel.PMT1, 0),
        (MeasurementChannel.PMT2, 6),
        (MeasurementChannel.US_LUM, 12),
    )

    def accumulate_one_window():
        # One pre-count window whose counts are added to the running total.
        seq_gen.TimerWaitAndRestart(pre_cnt_window)
        for counter, _ in counters:
            seq_gen.PulseCounterControl(counter, cumulative=True, resetCounter=False, resetPresetCounter=True, correctionOn=True)

    seq_gen.TimerWaitAndRestart(hv_gate_delay)
    seq_gen.SetSignals(OutputSignal.HVGatePMT1 | OutputSignal.HVGatePMT2)

    # Reset all counters before the measurement window starts.
    seq_gen.TimerWaitAndRestart(pre_cnt_window)
    for counter, _ in counters:
        seq_gen.PulseCounterControl(counter, cumulative=False, resetCounter=True, resetPresetCounter=True, correctionOn=False)

    if window_coarse > 0:
        seq_gen.Loop(window_coarse)
        seq_gen.Loop(65536)
        accumulate_one_window()
        seq_gen.LoopEnd()
        seq_gen.LoopEnd()
    if window_fine > 0:
        seq_gen.Loop(window_fine)
        accumulate_one_window()
        seq_gen.LoopEnd()

    # Per channel: count result first, then the dead time (which also resets the counter).
    for counter, base in counters:
        seq_gen.GetPulseCounterResult(counter, deadTime=False, relative=False, resetCounter=False, cumulative=True, dword=True, addrPos=0, resultPos=base)
        seq_gen.GetPulseCounterResult(counter, deadTime=True, relative=False, resetCounter=True, cumulative=True, dword=True, addrPos=0, resultPos=(base + 2))

    seq_gen.ResetSignals(OutputSignal.HVGatePMT1 | OutputSignal.HVGatePMT2)
    seq_gen.Stop(0)

    meas_unit.resultAddresses[op_id] = range(0, 16)
    await meas_unit.LoadTriggerSequence(op_id, seq_gen.currSequence)


async def pmt_analog_measurement(window_ms=1.0, iterations=100):
    """
    Run the combined analog / pulse counting measurement and return per-iteration means.

    The trigger sequence is (re)loaded through load_pmt_analog_measurement() with
    ``iterations`` passed as its window count, executed once, and the twelve result
    words are read back.

    Args:
        window_ms: Measurement window, forwarded to load_pmt_analog_measurement().
        iterations: Number of windows; every read-back value is divided by it.

    Returns:
        dict with the keys ``pmt{1,2}_cnt_mean``, ``pmt{1,2}_dt_mean``,
        ``pmt{1,2}_al_mean`` and ``pmt{1,2}_ah_mean``.
    """
    op_id = 'pmt_analog_measurement'
    meas_unit.ClearOperations()
    await load_pmt_analog_measurement(op_id, window_ms, iterations)

    await meas_unit.ExecuteMeasurement(op_id)
    raw = await meas_unit.ReadMeasurementValues(op_id)

    def wide(lsb_pos):
        # Counter and dead time values span two consecutive result words, low word first.
        return raw[lsb_pos] + (raw[lsb_pos + 1] << 32)

    # Result layout as written by load_pmt_analog_measurement().
    sums = (
        ('pmt1_cnt_mean', wide(0)),
        ('pmt1_dt_mean', wide(2)),
        ('pmt1_al_mean', raw[4]),
        ('pmt1_ah_mean', raw[5]),
        ('pmt2_cnt_mean', wide(6)),
        ('pmt2_dt_mean', wide(8)),
        ('pmt2_al_mean', raw[10]),
        ('pmt2_ah_mean', raw[11]),
    )

    return {key: total / iterations for key, total in sums}


async def load_pmt_analog_measurement(op_id, window_ms=1.0, window_count=100):
    """
    Build the trigger sequence for the combined analog / pulse counting measurement of
    PMT1 and PMT2 and load it into the measurement unit under ``op_id``.

    The sequence repeats ``window_count`` measurement windows and writes into twelve
    result buffer positions, which are registered as ``meas_unit.resultAddresses[op_id]``:

        0, 1    PMT1 pulse counter (lsb, msb)
        2, 3    PMT1 dead time (lsb, msb)
        4       PMT1 analog result, low range
        5       PMT1 analog result, high range
        6, 7    PMT2 pulse counter (lsb, msb)
        8, 9    PMT2 dead time (lsb, msb)
        10      PMT2 analog result, low range
        11      PMT2 analog result, high range

    Args:
        op_id: Key under which the result addresses and the sequence are stored.
        window_ms: Length of one measurement window in ms; rounded to whole microseconds.
        window_count: Number of measurement windows (outer loop count).

    Raises:
        ValueError: If ``window_ms`` is below 0.001 or ``window_count`` is outside [1, 65536].
    """
    if (window_ms < 0.001):
        raise ValueError(f"window_ms must be greater or equal to 0.001 ms")
    if (window_count < 1) or (window_count > 65536):
        raise ValueError(f"window_count must be in the range [1, 65536]")

    window_us = round(window_ms * 1000)

    # Number of counting steps split into full blocks of 65536 steps plus a remainder.
    # NOTE(review): presumably needed because a single Loop() count is limited to 65536,
    # the same bound that is enforced for window_count above -- confirm.
    window_coarse, window_fine = divmod(window_us, 65536)

    # Timer values; the unit comments imply that one timer tick is 10 ns.
    full_reset_delay = 40000    # 400 us
    pre_cnt_window = 100        #   1 us
    conversion_delay = 1200     #  12 us
    switch_delay = 25           # 250 ns
    fixed_range = 2000          #  20 us
    reset_switch_delay = 2000   #  20 us
    input_gate_delay = 100      #   1 us

    seq_gen = meas_seq_generator()

    # Clear all twelve result words once, before the window loop starts.
    seq_gen.ClearResultBuffer(relative=False, dword=False, addrReg=0, addr= 0)  # pmt1_cr_lsb
    seq_gen.ClearResultBuffer(relative=False, dword=False, addrReg=0, addr= 1)  # pmt1_cr_msb
    seq_gen.ClearResultBuffer(relative=False, dword=False, addrReg=0, addr= 2)  # pmt1_dt_lsb
    seq_gen.ClearResultBuffer(relative=False, dword=False, addrReg=0, addr= 3)  # pmt1_dt_msb
    seq_gen.ClearResultBuffer(relative=False, dword=False, addrReg=0, addr= 4)  # pmt1_alr
    seq_gen.ClearResultBuffer(relative=False, dword=False, addrReg=0, addr= 5)  # pmt1_ahr
    seq_gen.ClearResultBuffer(relative=False, dword=False, addrReg=0, addr= 6)  # pmt2_cr_lsb
    seq_gen.ClearResultBuffer(relative=False, dword=False, addrReg=0, addr= 7)  # pmt2_cr_msb
    seq_gen.ClearResultBuffer(relative=False, dword=False, addrReg=0, addr= 8)  # pmt2_dt_lsb
    seq_gen.ClearResultBuffer(relative=False, dword=False, addrReg=0, addr= 9)  # pmt2_dt_msb
    seq_gen.ClearResultBuffer(relative=False, dword=False, addrReg=0, addr=10)  # pmt2_alr
    seq_gen.ClearResultBuffer(relative=False, dword=False, addrReg=0, addr=11)  # pmt2_ahr

    # Input gate signals reset, HV gate signals set for the whole measurement.
    seq_gen.ResetSignals(OutputSignal.InputGatePMT1 | OutputSignal.InputGatePMT2)
    seq_gen.SetSignals(OutputSignal.HVGatePMT1 | OutputSignal.HVGatePMT2)

    # One pass of this loop per measurement window.
    seq_gen.Loop(window_count)
    seq_gen.SetAnalogControl(pmt1=AnalogControlMode.full_offset_reset, pmt2=AnalogControlMode.full_offset_reset)

    seq_gen.TimerWaitAndRestart(full_reset_delay)
    seq_gen.SetIntegratorMode(pmt1=IntegratorMode.full_reset, pmt2=IntegratorMode.full_reset)
    seq_gen.ResetSignals(OutputSignal.InputGatePMT1 | OutputSignal.InputGatePMT2)

    seq_gen.TimerWaitAndRestart(conversion_delay)
    seq_gen.SetTriggerOutput(TriggerSignal.SamplePMT1 | TriggerSignal.SamplePMT2)

    seq_gen.TimerWaitAndRestart(switch_delay)
    seq_gen.SetIntegratorMode(pmt1=IntegratorMode.low_range_reset, pmt2=IntegratorMode.low_range_reset)
    seq_gen.SetAnalogControl(pmt1=AnalogControlMode.read_offset, pmt2=AnalogControlMode.read_offset)

    seq_gen.TimerWaitAndRestart(reset_switch_delay)
    seq_gen.SetTriggerOutput(TriggerSignal.SamplePMT1 | TriggerSignal.SamplePMT2)
    seq_gen.SetIntegratorMode(pmt1=IntegratorMode.integrate_in_low_range, pmt2=IntegratorMode.integrate_in_low_range)

    # Input gate signals are set once the integrators are in low range integration mode.
    seq_gen.TimerWaitAndRestart(input_gate_delay)
    seq_gen.SetSignals(OutputSignal.InputGatePMT1 | OutputSignal.InputGatePMT2)

    # Reset both pulse counters after one pre_cnt_window, then run window_us counting
    # steps of pre_cnt_window each (window_coarse * 65536 + window_fine steps in total).
    seq_gen.TimerWaitAndRestart(pre_cnt_window)
    seq_gen.PulseCounterControl(MeasurementChannel.PMT1, cumulative=False, resetCounter=True, resetPresetCounter=True, correctionOn=False)
    seq_gen.PulseCounterControl(MeasurementChannel.PMT2, cumulative=False, resetCounter=True, resetPresetCounter=True, correctionOn=False)
    if window_coarse > 0:
        seq_gen.Loop(window_coarse)
        seq_gen.Loop(65536)
        seq_gen.TimerWaitAndRestart(pre_cnt_window)
        seq_gen.PulseCounterControl(MeasurementChannel.PMT1, cumulative=True, resetCounter=False, resetPresetCounter=True, correctionOn=True)
        seq_gen.PulseCounterControl(MeasurementChannel.PMT2, cumulative=True, resetCounter=False, resetPresetCounter=True, correctionOn=True)
        seq_gen.LoopEnd()
        seq_gen.LoopEnd()
    if window_fine > 0:
        seq_gen.Loop(window_fine)
        seq_gen.TimerWaitAndRestart(pre_cnt_window)
        seq_gen.PulseCounterControl(MeasurementChannel.PMT1, cumulative=True, resetCounter=False, resetPresetCounter=True, correctionOn=True)
        seq_gen.PulseCounterControl(MeasurementChannel.PMT2, cumulative=True, resetCounter=False, resetPresetCounter=True, correctionOn=True)
        seq_gen.LoopEnd()

    # End of the counting window: input gate signals reset again before the analog readout.
    seq_gen.ResetSignals(OutputSignal.InputGatePMT1 | OutputSignal.InputGatePMT2)
    seq_gen.SetAnalogControl(pmt1=AnalogControlMode.read_offset, pmt2=AnalogControlMode.read_offset)

    # Low range analog results go to positions 4 (PMT1) and 10 (PMT2).
    # NOTE(review): addResult=True presumably accumulates over all windows; the caller
    # pmt_analog_measurement() divides by the window count afterwards -- confirm.
    seq_gen.TimerWaitAndRestart(fixed_range)
    seq_gen.SetTriggerOutput(TriggerSignal.SamplePMT1 | TriggerSignal.SamplePMT2)
    seq_gen.SetIntegratorMode(pmt1=IntegratorMode.integrate_in_high_range, pmt2=IntegratorMode.integrate_in_high_range)
    seq_gen.GetAnalogResult(MeasurementChannel.PMT1, isRelativeAddr=False, ignoreRange=False, isHiRange=False, addResult=True, dword=False, addrPos=0, resultPos=4)
    seq_gen.GetAnalogResult(MeasurementChannel.PMT2, isRelativeAddr=False, ignoreRange=False, isHiRange=False, addResult=True, dword=False, addrPos=0, resultPos=10)
    
    # High range analog results go to positions 5 (PMT1) and 11 (PMT2); afterwards the
    # pulse counter and dead time values are stored, the dead time read also resetting the counter.
    seq_gen.TimerWaitAndRestart(conversion_delay)
    seq_gen.SetTriggerOutput(TriggerSignal.SamplePMT1 | TriggerSignal.SamplePMT2)
    seq_gen.GetAnalogResult(MeasurementChannel.PMT1, isRelativeAddr=False, ignoreRange=False, isHiRange=True, addResult=True, dword=False, addrPos=0, resultPos=5)
    seq_gen.GetAnalogResult(MeasurementChannel.PMT2, isRelativeAddr=False, ignoreRange=False, isHiRange=True, addResult=True, dword=False, addrPos=0, resultPos=11)
    seq_gen.GetPulseCounterResult(MeasurementChannel.PMT1, deadTime=False, relative=False, resetCounter=False, cumulative=True, dword=True, addrPos=0, resultPos=0)
    seq_gen.GetPulseCounterResult(MeasurementChannel.PMT1, deadTime=True, relative=False, resetCounter=True, cumulative=True, dword=True, addrPos=0, resultPos=2)
    seq_gen.GetPulseCounterResult(MeasurementChannel.PMT2, deadTime=False, relative=False, resetCounter=False, cumulative=True, dword=True, addrPos=0, resultPos=6)
    seq_gen.GetPulseCounterResult(MeasurementChannel.PMT2, deadTime=True, relative=False, resetCounter=True, cumulative=True, dword=True, addrPos=0, resultPos=8)

    # Closes the window_count loop.
    seq_gen.LoopEnd()

    # After the last window: HV gate signals reset, analog control and integrators back to full reset.
    seq_gen.ResetSignals(OutputSignal.HVGatePMT1 | OutputSignal.HVGatePMT2)
    seq_gen.SetAnalogControl(pmt1=AnalogControlMode.full_offset_reset, pmt2=AnalogControlMode.full_offset_reset)
    seq_gen.SetIntegratorMode(pmt1=IntegratorMode.full_reset, pmt2=IntegratorMode.full_reset)
    seq_gen.Stop(0)

    # Register the twelve result positions for read back and load the sequence.
    meas_unit.resultAddresses[op_id] = range(0, 12)
    await meas_unit.LoadTriggerSequence(op_id, seq_gen.currSequence)


def load_led_characteristic(source='fmb', channel='led1', type='green', suffix=''):
    """
    Load a stored LED power scan from the ``pmt_adjust_calibration`` directory next to this file.

    The file ``{source}_{channel}_{type}{suffix}.json`` must contain a JSON object; its keys
    become the ``current`` array and its values the ``power`` array (both as float).

    Returns:
        dict with the keys ``source``, ``channel``, ``type``, ``current`` and ``power``.
    """
    calibration_dir = f"{os.path.dirname(__file__)}/pmt_adjust_calibration"
    calibration_path = f"{calibration_dir}/{source}_{channel}_{type}{suffix}.json"
    with open(calibration_path, 'r') as handle:
        scan = json.load(handle)

    currents = list(scan.keys())
    powers = list(scan.values())

    characteristic = {'source':source, 'channel':channel, 'type':type}
    characteristic['current'] = np.array(currents, dtype=float)
    characteristic['power'] = np.array(powers, dtype=float)
    return characteristic


async def set_led_current(current, source='smu', channel='led1', led_type='red'):
    """
    Set the LED current through the driver selected by the prefix of ``source``.

    Sources starting with ``'smu'`` are handled by smu_set_led_current(), sources starting
    with ``'fmb'`` by fmb_set_led_current(). The value returned by that driver is passed on.

    Raises:
        Exception: If ``source`` starts with neither prefix.
    """
    if source.startswith('smu'):
        setter = smu_set_led_current
    elif source.startswith('fmb'):
        setter = fmb_set_led_current
    else:
        raise Exception(f"Invalid source: {source}")

    return await setter(current, source, channel, led_type)


# Compliance voltage per LED color, assigned to the SMU's `compliance_voltage`
# in smu_set_led_current(). NOTE(review): presumably in volts -- confirm.
led_compliance_voltage = {'red':2.0, 'green':2.9, 'blue':3.1}

async def smu_set_led_current(current=1, source='smu', channel='led1', led_type='green'):
    """
    Drive the LED from a Keithley 2450 source meter.

    The SMU is configured as a current source with the compliance voltage of the LED color
    (the part of ``led_type`` before the first ``'_'``). ``current`` is rounded to an integer
    and scaled by 1e-6 for the instrument; values outside (0, 30000] shut the SMU down instead.
    NOTE(review): the 1e-6 factor suggests ``current`` is given in uA -- confirm.

    Returns:
        The rounded ``current``. It is returned unchanged even when the SMU was shut down.

    Raises:
        ValueError: If ``source`` does not start with ``'smu'``.
        KeyError: If the LED color has no entry in led_compliance_voltage.
    """
    if not source.startswith('smu'):
        raise ValueError(f"Invalid source: {source}")

    # Imported lazily so that the module loads without pymeasure when no SMU is used.
    from pymeasure.instruments.keithley import Keithley2450

    color, _, _ = led_type.partition('_')

    # VISA address: USB[board]::manufacturer ID::model code::serial number[::USB interface number][::INSTR]
    smu = Keithley2450('USB0::1510::9296::?*::INSTR')
    smu.apply_current()
    smu.measure_voltage()
    smu.compliance_voltage = led_compliance_voltage[color]

    current = round(current)
    if current <= 0 or current > 30000:
        smu.shutdown()
    else:
        smu.source_current = current * 1e-6
        smu.enable_source()

    await asyncio.sleep(0.2)
    return current


# Mainboard analog outputs per LED channel as used by fmb_set_led_current():
# 'bc' receives the brightness value, 'gs' is switched between 1.0 (on) and 0.0 (off).
# Note that 'led1' and 'led3' share the same 'bc' output (BCG).
fmb_led_channel = {
    'led1' : {'bc':FMBAnalogOutput.BCG, 'gs':FMBAnalogOutput.GSG1},
    'led2' : {'bc':FMBAnalogOutput.BCR, 'gs':FMBAnalogOutput.GSR3},
    'led3' : {'bc':FMBAnalogOutput.BCG, 'gs':FMBAnalogOutput.GSG2},
    'led4' : {'bc':FMBAnalogOutput.BCB, 'gs':FMBAnalogOutput.GSB3},
}

# Full scale value per source name; fmb_set_led_current() divides the requested current by it
# to get the brightness. With the plain 'fmb' source the "current" is the brightness step itself.
# NOTE(review): the unit of the measured entries is not stated here (presumably uA) -- confirm.
fmb_full_scale_current = {
    'fmb'    :   127,   # Brightness Value [0, 127]
    'fmb#7'  :  2129,   # Measured with Keithley Multimeter (FMB#7,  24k3, Rev 2)
    'fmb#12' :  2084,   # Measured with Keithley Multimeter (FMB#12, 24k3, Rev 2)
    'fmb#15' : 11084,   # Measured with Keithley Multimeter (FMB#15,  4k7, Rev 2)
}

async def fmb_set_led_current(current=17, source='fmb', channel='led1', led_type='green'):
    """
    Drive the LED from the mainboard analog outputs.

    The brightness is ``current`` divided by the full scale value of ``source``. The returned
    current is that brightness truncated to a step of 1/128, capped at step 127 and scaled back
    to the full scale value. When it is positive the 'bc' output gets the (unquantized and
    uncapped) brightness and the 'gs' output 1.0, otherwise both outputs are set to 0.0.
    ``led_type`` is not used here.

    Returns:
        The quantized current as an integer.

    Raises:
        KeyError: If ``channel`` or ``source`` is not a known key.
    """
    outputs = fmb_led_channel[channel]
    bc_output = outputs['bc']
    gs_output = outputs['gs']
    full_scale = fmb_full_scale_current[source]

    brightness = current / full_scale
    step = min(int(brightness * 128), 127)
    current = round(step / 127 * full_scale)

    led_on = current > 0
    await fmb_endpoint.SetAnalogOutput(bc_output, brightness if led_on else 0.0)
    await fmb_endpoint.SetAnalogOutput(gs_output, 1.0 if led_on else 0.0)

    await asyncio.sleep(0.2)
    return current


async def opm_led_power_scan(current_start=1, current_stop=127, current_steps=127, logspace=False, source='fmb', channel='led3', led_type='green'):
    """
    Scan the LED current, measure the optical power at every step and store the scan as JSON.

    For every current of the scan the LED is switched off, then set to the current, and the
    power meter is read after a settling delay. The scan is keyed by the current returned from
    set_led_current(), so steps that map to the same returned current overwrite each other.
    The result is written to ``pmt_adjust_calibration/{source}_{channel}_{led_type}_{timestamp}_filter_{filter}.json``
    next to this file.

    The LED is switched off when the scan ends, also when it is aborted by an exception
    (e.g. an overrange error of the power meter) or by cancellation.

    Args:
        current_start: First current of the scan.
        current_stop: Last current of the scan.
        current_steps: Number of scan points.
        logspace: Use logarithmic instead of linear spacing of the scan points.
        source, channel, led_type: Forwarded to set_led_current(); ``led_type`` is also
            passed to opm_init() to select the wavelength.

    Returns:
        dict mapping the set current to the measured power multiplied by 1e9 and rounded to
        two decimals. NOTE(review): nW if the meter reports W -- confirm.
    """
    on_delay  = 5   # Switch the LED on and wait before measuring the power
    off_delay = 5   # Switch the LED off and wait before the next step

    if logspace:
        current_range = np.logspace(np.log10(current_start), np.log10(current_stop), current_steps)
    else:
        current_range = np.linspace(current_start, current_stop, current_steps)

    opm_info = opm_init(led_type)
    await send_to_gc(f"{opm_info}", log=True)

    await send_to_gc("current ; power", log=True)

    power_scan = {}
    try:
        for requested in current_range:
            await set_led_current(0, source, channel, led_type)
            await asyncio.sleep(off_delay)
            current = await set_led_current(requested, source, channel, led_type)
            await asyncio.sleep(on_delay)
            power_scan[current] = np.round(opm_get_power() * 1e9, decimals=2)

            await send_to_gc(f"{current:7.0f} ; {power_scan[current]:11.2f}", log=True)
    finally:
        # Never leave the LED on, not even when the scan is aborted half way.
        await set_led_current(0, source, channel, led_type)

    filter_state = opm_info['filter']
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    calibration_dir = f"{os.path.dirname(__file__)}/pmt_adjust_calibration"
    calibration_file = f"{source}_{channel}_{led_type}_{timestamp}_filter_{filter_state}.json"
    os.makedirs(calibration_dir, exist_ok=True)
    with open(f"{calibration_dir}/{calibration_file}", 'w') as file:
        json.dump(power_scan, file, indent=4)

    return power_scan


# Wavelength per LED color name; opm_init() sends it to the power meter with the "WL" command.
# NOTE(review): presumably in nm -- confirm against the power meter documentation.
led_wavelength = {'red':635, 'green':525, 'blue':465}

# Serial connection of the optical power meter (OPM), used by opm_init() and opm_get_power().
opm_serial_port = '/dev/ttyUSB0'
opm_bautrate = 115200   # Baud rate passed to Serial(); the name's spelling is kept because other code refers to it

def opm_init(wavelength='red'):
    """
    Configure the optical power meter and report its identification and filter state.

    Args:
        wavelength: A key of led_wavelength (color name) or the value to send directly.

    Returns:
        dict with ``idn`` (identification reply), ``wavelength`` (value sent to the meter)
        and ``filter`` (third space separated token of the "FQ" reply).

    Raises:
        Exception: If the identification reply does not start with ``STBR``.
    """
    # Color names are translated via led_wavelength, anything else is passed on as given.
    wavelength = led_wavelength.get(wavelength, wavelength)

    with Serial(opm_serial_port, opm_bautrate, timeout=2) as port:
        idn = opm_send_command(port, "II")
        model, *_ = idn.split(' ')
        if model != 'STBR':
            raise Exception(f"Failed to read power meter identification: {idn}")

        setup_commands = (
            "FP",                   # Power Measurement
            "WN -1",                # Autorange
            f"WL {wavelength}",     # Wavelength
            "AQ 1",                 # No Average
        )
        for command in setup_commands:
            opm_send_command(port, command)

        filter_reply = opm_send_command(port, "FQ")
        filter_state = filter_reply.split(' ')[2]

    return {'idn':idn, 'wavelength':wavelength, 'filter':filter_state}


def opm_get_power() -> float:
    """
    Read one power value from the optical power meter with the "SP" command.

    Returns:
        The reply converted to float.

    Raises:
        Exception: If the meter replies ``over`` (any letter case).
        ValueError: If the reply is not a valid float.
    """
    with Serial(opm_serial_port, opm_bautrate, timeout=2) as port:
        reading = opm_send_command(port, "SP")

    if reading.lower() == 'over':
        raise Exception("Overrange error")

    return float(reading)


def opm_send_command(serial: Serial, command: str) -> str:
    """
    Send one command to the optical power meter and return the payload of its reply.

    The command is sent as ``$<command>\\r\\n``. A valid reply is a single line starting
    with ``*``; the marker and the line terminator are removed from the returned text.

    Args:
        serial: Open serial connection to the power meter.
        command: Command text without the leading ``$``.

    Returns:
        The reply without the leading ``*`` and without trailing CR/LF characters.

    Raises:
        Exception: If the reply does not start with ``*``. This includes an empty reply,
            i.e. when nothing was received from the port.
    """
    serial.write(f'${command}\r\n'.encode('utf-8'))
    response = serial.readline().decode('utf-8')
    if not response.startswith('*'):
        raise Exception(f"Invalid power meter command: {command} -> {response}")

    # Strip only the terminator that actually arrived. Cutting a fixed two characters would
    # drop payload when the reply ends with a bare '\n' or was cut short by the read timeout.
    return response[1:].rstrip('\r\n')


async def pmt_dark_signal(channel='pmt3', duration_s=60):
    """
    Measure and log the dark count rate of one PMT channel.

    Starts the mainboard and EEF firmware, applies the channel's ``dl`` and ``hv`` settings,
    enables the high voltage and then runs ``duration_s`` counting measurements, each with one
    window of ``dark_window_ms`` (1000 ms). Mean and standard deviation of the count rate are
    sent to the GC log after every measurement.
    NOTE(review): ``dl`` / ``hv`` are presumably the discriminator level and high voltage
    settings -- confirm against pmt_set_dl() / pmt_set_hv().

    The high voltage is disabled again when the function ends, also when a measurement
    fails or the task is cancelled.

    Args:
        channel: ``'pmt1'``, ``'pmt2'`` or ``'pmt3'``.
        duration_s: Number of measurements to run.

    Raises:
        KeyError: If ``channel`` has no entry in the settings tables.
    """
    dl = {'pmt1':0.260, 'pmt2':0.260, 'pmt3':0.375}
    hv = {'pmt1':0.475, 'pmt2':0.475, 'pmt3':0.535}
    dark_window_ms  = 1000
    dark_iterations = 1

    await asyncio.gather(
        fmb_unit.StartFirmware(),
        eef_unit.StartFirmware(),
    )
    await pmt_set_dl(channel, dl[channel])
    await pmt_set_hv(channel, hv[channel])
    await asyncio.sleep(0.2)
    await pmt_set_hv_enable(channel, 1)
    try:
        await asyncio.sleep(0.5)

        for second in range(duration_s):
            results = await pmt_counting_measurement(dark_window_ms, dark_iterations)
            dark_cps = results[f"{channel}_cps_mean"]
            dark_std = results[f"{channel}_cps_std"]

            await send_to_gc(f"{second:3d} dark_cps: {dark_cps:6.0f}, dark_std: {dark_std:6.0f}", log=True)
    finally:
        # Never leave the high voltage enabled, not even when the measurement is aborted.
        await pmt_set_hv_enable(channel, 0)

